{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d124d22-de73-436b-86cd-9b162b469be8",
   "metadata": {
    "id": "2d124d22-de73-436b-86cd-9b162b469be8"
   },
   "outputs": [],
   "source": [
    "%pip install langchain_community\n",
    "%pip install langchain_experimental\n",
    "%pip install langchain-openai\n",
    "%pip install langchainhub\n",
    "%pip install chromadb\n",
    "%pip install langchain\n",
    "%pip install python-dotenv\n",
    "%pip install PyPDF2 -q --user\n",
    "%pip install rank_bm25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f884314f-870c-4bfb-b6c1-a5b4801ec172",
   "metadata": {
    "executionInfo": {
     "elapsed": 7390,
     "status": "ok",
     "timestamp": 1715455737041,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": 240
    },
    "id": "f884314f-870c-4bfb-b6c1-a5b4801ec172"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import openai\n",
    "from langchain_openai import ChatOpenAI, OpenAIEmbeddings\n",
    "from langchain import hub\n",
    "from langchain_core.output_parsers import StrOutputParser\n",
    "from langchain_core.runnables import RunnablePassthrough\n",
    "import chromadb\n",
    "from langchain_community.vectorstores import Chroma\n",
    "from langchain_core.runnables import RunnableParallel\n",
    "from dotenv import load_dotenv, find_dotenv\n",
    "from langchain_core.prompts import PromptTemplate\n",
    "from PyPDF2 import PdfReader\n",
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "from langchain.docstore.document import Document\n",
    "from langchain_community.retrievers import BM25Retriever\n",
    "\n",
    "# new\n",
    "from langchain.retrievers import EnsembleRetriever"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "eba3468a-d7c2-4a79-8df2-c335542950f2",
   "metadata": {
    "executionInfo": {
     "elapsed": 5,
     "status": "ok",
     "timestamp": 1715455737042,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": 240
    },
    "id": "eba3468a-d7c2-4a79-8df2-c335542950f2"
   },
   "outputs": [],
   "source": [
    "# variables\n",
    "_ = load_dotenv(dotenv_path='env.txt')\n",
    "os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY')\n",
    "openai.api_key = os.environ['OPENAI_API_KEY']\n",
    "embedding_function = OpenAIEmbeddings()\n",
    "llm = ChatOpenAI(model_name=\"gpt-4o-mini\", temperature=0)\n",
    "pdf_path = \"google-2023-environmental-report.pdf\"\n",
    "collection_name = \"google_environmental_report\"\n",
    "str_output_parser = StrOutputParser()\n",
    "user_query = \"What are Google's environmental initiatives?\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d3ad428a-3eb6-40ec-a1a5-62565ead1e5b",
   "metadata": {
    "id": "d3ad428a-3eb6-40ec-a1a5-62565ead1e5b"
   },
   "outputs": [],
   "source": [
    "#### INDEXING ####"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "98ccda2c-0f4c-41c5-804d-2227cdf35aa7",
   "metadata": {
    "executionInfo": {
     "elapsed": 19512,
     "status": "ok",
     "timestamp": 1715455756551,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": 240
    },
    "id": "98ccda2c-0f4c-41c5-804d-2227cdf35aa7"
   },
   "outputs": [],
   "source": [
    "# Load the PDF and extract text\n",
    "pdf_reader = PdfReader(pdf_path)\n",
    "text = \"\"\n",
    "for page in pdf_reader.pages:\n",
    "    text += page.extract_text()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "927a4c65-aa05-486c-8295-2f99673e7c20",
   "metadata": {
    "executionInfo": {
     "elapsed": 12,
     "status": "ok",
     "timestamp": 1715455756552,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": 240
    },
    "id": "927a4c65-aa05-486c-8295-2f99673e7c20"
   },
   "outputs": [],
   "source": [
    "# Split\n",
    "character_splitter = RecursiveCharacterTextSplitter(\n",
    "    separators=[\"\\n\\n\", \"\\n\", \". \", \" \", \"\"],\n",
    "    chunk_size=1000,\n",
    "    chunk_overlap=200\n",
    ")\n",
    "splits = character_splitter.split_text(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c4ccbc8-256f-4152-8ef0-684ef119245e",
   "metadata": {},
   "outputs": [],
   "source": [
    "dense_documents = [Document(page_content=text, metadata={\"id\": str(i), \"source\": \"dense\"}) for i, text in enumerate(splits)]\n",
    "sparse_documents = [Document(page_content=text, metadata={\"id\": str(i), \"source\": \"sparse\"}) for i, text in enumerate(splits)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b13568c-d633-464d-8c43-0d55f34cc8c1",
   "metadata": {
    "executionInfo": {
     "elapsed": 9471,
     "status": "ok",
     "timestamp": 1715455766015,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": 240
    },
    "id": "6b13568c-d633-464d-8c43-0d55f34cc8c1"
   },
   "outputs": [],
   "source": [
    "chroma_client = chromadb.Client()\n",
    "vectorstore = Chroma.from_documents(\n",
    "    documents=dense_documents,\n",
    "    embedding=embedding_function,\n",
    "    collection_name=collection_name,\n",
    "    client=chroma_client\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "515f7ec0-ee09-4a03-a5be-4d8661a8121c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create dense retriever\n",
    "dense_retriever = vectorstore.as_retriever(search_kwargs={\"k\": 10})\n",
    "# Create sparse retriever\n",
    "sparse_retriever = BM25Retriever.from_documents(sparse_documents, k=10)\n",
    "# initialize the ensemble retriever\n",
    "ensemble_retriever = EnsembleRetriever(retrievers=[dense_retriever, sparse_retriever], weights=[0.5, 0.5], c=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ce8df01-925b-45b5-8fb8-17b5c40c581f",
   "metadata": {
    "id": "6ce8df01-925b-45b5-8fb8-17b5c40c581f"
   },
   "outputs": [],
   "source": [
    "#### RETRIEVAL and GENERATION ####"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fac053d8-b871-4b50-b04e-28dec9fb3b0f",
   "metadata": {
    "executionInfo": {
     "elapsed": 355,
     "status": "ok",
     "timestamp": 1715455766356,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": 240
    },
    "id": "fac053d8-b871-4b50-b04e-28dec9fb3b0f"
   },
   "outputs": [],
   "source": [
    "# Prompt\n",
    "prompt = hub.pull(\"jclemens24/rag-prompt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ef30632-13dd-4a34-af33-cb8fab94f169",
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1715455766356,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": 240
    },
    "id": "5ef30632-13dd-4a34-af33-cb8fab94f169"
   },
   "outputs": [],
   "source": [
    "# Relevance check prompt\n",
    "relevance_prompt_template = PromptTemplate.from_template(\n",
    "    \"\"\"\n",
    "    Given the following question and retrieved context, determine if the context is relevant to the question.\n",
    "    Provide a score from 1 to 5, where 1 is not at all relevant and 5 is highly relevant.\n",
    "    Return ONLY the numeric score, without any additional text or explanation.\n",
    "\n",
    "    Question: {question}\n",
    "    Retrieved Context: {retrieved_context}\n",
    "\n",
    "    Relevance Score:\"\"\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8975479-b3e3-481d-ad7b-08b4eb3faaef",
   "metadata": {
    "executionInfo": {
     "elapsed": 2,
     "status": "ok",
     "timestamp": 1715455766356,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": 240
    },
    "id": "e8975479-b3e3-481d-ad7b-08b4eb3faaef"
   },
   "outputs": [],
   "source": [
    "# Post-processing\n",
    "def format_docs(docs):\n",
    "    return \"\\n\\n\".join(doc.page_content for doc in docs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "deb6d70c-42ef-4bda-9607-48f02c941280",
   "metadata": {
    "executionInfo": {
     "elapsed": 2,
     "status": "ok",
     "timestamp": 1715455766356,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": 240
    },
    "id": "deb6d70c-42ef-4bda-9607-48f02c941280"
   },
   "outputs": [],
   "source": [
    "# LLM\n",
    "llm = ChatOpenAI(model_name=\"gpt-4o\", temperature=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd9db713-f705-4b65-800e-2c4e3d0e4ef4",
   "metadata": {
    "executionInfo": {
     "elapsed": 6,
     "status": "ok",
     "timestamp": 1715455766512,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": 240
    },
    "id": "fd9db713-f705-4b65-800e-2c4e3d0e4ef4"
   },
   "outputs": [],
   "source": [
    "def extract_score(llm_output):\n",
    "    try:\n",
    "        score = float(llm_output.strip())\n",
    "        return score\n",
    "    except ValueError:\n",
    "        return 0\n",
    "\n",
    "# Chain it all together with LangChain\n",
    "def conditional_answer(x):\n",
    "    relevance_score = extract_score(x['relevance_score'])\n",
    "    if relevance_score < 4:\n",
    "        return \"I don't know.\"\n",
    "    else:\n",
    "        return x['answer']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09213587-e1c6-423e-92df-71d2e008992f",
   "metadata": {},
   "outputs": [],
   "source": [
    "rag_chain_from_docs = (\n",
    "    RunnablePassthrough.assign(context=(lambda x: format_docs(x[\"context\"])))\n",
    "    | RunnableParallel(\n",
    "        {\"relevance_score\": (\n",
    "            RunnablePassthrough()\n",
    "            | (lambda x: relevance_prompt_template.format(question=x['question'], retrieved_context=x['context']))\n",
    "            | llm\n",
    "            | str_output_parser\n",
    "        ), \"answer\": (\n",
    "            RunnablePassthrough()\n",
    "            | prompt\n",
    "            | llm\n",
    "            | str_output_parser\n",
    "        )}\n",
    "    )\n",
    "    | RunnablePassthrough().assign(final_answer=conditional_answer)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc5c2ab0-9191-40f7-abf2-681f1c751429",
   "metadata": {
    "executionInfo": {
     "elapsed": 5,
     "status": "ok",
     "timestamp": 1715455766512,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": 240
    },
    "id": "dc5c2ab0-9191-40f7-abf2-681f1c751429"
   },
   "outputs": [],
   "source": [
    "rag_chain_with_source = RunnableParallel(\n",
    "    {\"context\": ensemble_retriever, \"question\": RunnablePassthrough()}\n",
    ").assign(answer=rag_chain_from_docs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b30177a-f9ab-45e4-812d-33b0f97325bd",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 3524,
     "status": "ok",
     "timestamp": 1715455770031,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": 240
    },
    "id": "8b30177a-f9ab-45e4-812d-33b0f97325bd",
    "outputId": "fb5f615f-202e-4409-f4c7-10aac3a577ee",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# User Query\n",
    "result = rag_chain_with_source.invoke(user_query)\n",
    "retrieved_docs = result['context']\n",
    "\n",
    "print(f\"Original Question: {user_query}\\n\")\n",
    "print(f\"Relevance Score: {result['answer']['relevance_score']}\\n\")\n",
    "print(f\"Final Answer:\\n{result['answer']['final_answer']}\\n\\n\")\n",
    "print(\"Retrieved Documents:\")\n",
    "for i, doc in enumerate(retrieved_docs, start=1):\n",
    "    print(f\"Document {i}: Document ID: {doc.metadata['id']} source: {doc.metadata['source']}\")\n",
    "    print(f\"Content:\\n{doc.page_content}\\n\")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "CHAPTER8-2_HYBRID-ENSEMBLE.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
